{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2\n",
    "import os\n",
    "import sys\n",
    "import panel as pn\n",
    "import numpy as np\n",
    "import pyvista as pv\n",
    "pn.extension('vtk')\n",
    "os.system('/usr/bin/Xvfb :99 -screen 0 1024x768x24 &')\n",
    "os.environ['DISPLAY'] = ':99'\n",
    "os.environ['PYVISTA_OFF_SCREEN'] = 'True'\n",
    "os.environ['PYVISTA_USE_PANEL'] = 'True'"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Experiment Manager"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "DIR = os.path.dirname(os.getcwd())\n",
    "ROOT = os.path.join(DIR, \"..\")\n",
    "sys.path.insert(0, ROOT)\n",
    "sys.path.insert(0, DIR)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from torch_points3d.visualization.experiment_manager import ExperimentManager"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "experiment_manager = ExperimentManager(DIR)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def update_pointcloud(file):\n",
    "    camera = [ (3.8930294513702393, 2.595715045928955, -2.18937349319458),\n",
    " (0.14671478027594276, 0.2490156926214695, -0.07320725917816162),\n",
    " (-0.4286445677280426, 0.8776587247848511, 0.21442723274230957)]\n",
    "    try:\n",
    "        experiment_manager.load_ply_file(file)\n",
    "    except:\n",
    "        print(\"Could not load corrupted file %s\" % file)\n",
    "        return\n",
    "    pl1 = pv.Plotter(notebook=True)\n",
    "    pl1.camera_position =  camera\n",
    "    point_cloud = pv.PolyData(experiment_manager.current_pointcloud['xyz'])\n",
    "    point_cloud['l'] = experiment_manager.current_pointcloud['l']\n",
    "    pl1.add_points(point_cloud)\n",
    "\n",
    "    pl2 = pv.Plotter(notebook=True)\n",
    "    pl2.camera_position =  camera\n",
    "    point_cloud = pv.PolyData(experiment_manager.current_pointcloud['xyz'])\n",
    "    point_cloud['p'] = experiment_manager.current_pointcloud['p']\n",
    "    pl2.add_points(point_cloud)\n",
    "    \n",
    "    pl3 = pv.Plotter(notebook=True)\n",
    "    pl3.camera_position =  camera\n",
    "    point_cloud = pv.PolyData(experiment_manager.current_pointcloud['xyz'])\n",
    "    point_cloud['e'] = experiment_manager.current_pointcloud['p'] == experiment_manager.current_pointcloud['l']\n",
    "    pl3.add_points(point_cloud)\n",
    "     \n",
    "    pan_left.object = pl1.ren_win\n",
    "    pan_right.object = pl2.ren_win\n",
    "    pan_err.object = pl3.ren_win\n",
    "    \n",
    "    \n",
    "\n",
    "def sync_pans(pans):\n",
    "    ref = pans[0]\n",
    "    for pan in pans[1:]:\n",
    "        ref.jslink(pan, camera='camera', bidirectional=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "selector_model_name = pn.widgets.Select(name='Model name', options=experiment_manager.model_name_wviz)\n",
    "selector_paths = pn.widgets.Select(name='Available runs', options=experiment_manager.get_model_wviz_paths(selector_model_name.value))\n",
    "selector_epoch =  pn.widgets.Select(name='Epoch',options = experiment_manager.from_paths_to_epoch(selector_paths.value))\n",
    "selector_split =  pn.widgets.Select(name='Split',options = experiment_manager.from_epoch_to_split(selector_epoch.value))\n",
    "selector_file =  pn.widgets.Select(name='Sample',options = experiment_manager.from_split_to_file(selector_split.value))\n",
    "\n",
    "pl_left = pv.Plotter(notebook=True)\n",
    "pan_left = pn.panel(pl_left.ren_win, sizing_mode='stretch_height', orientation_widget=True)\n",
    "\n",
    "pl_right = pv.Plotter(notebook=True)\n",
    "pan_right = pn.panel(pl_right.ren_win, sizing_mode='stretch_height',)\n",
    "\n",
    "pl_err = pv.Plotter(notebook=True)\n",
    "pan_err = pn.panel(pl_err.ren_win, sizing_mode='stretch_height',)\n",
    "\n",
    "\n",
    "pans = pn.Row(pn.Column(pan_left, \"### Ground truth\"),\n",
    "              pn.Column(pan_right, \"### Prediction\"),\n",
    "             pn.Column(pan_err, \"### Error\"))\n",
    "sync_pans([pan_left, pan_right, pan_err])\n",
    "col = pn.Column(selector_model_name, selector_paths, selector_epoch, selector_split, selector_file)\n",
    "\n",
    "dashboard = pn.Column(\n",
    "    '## Select file to display',\n",
    "    pn.Row(col, pans)\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def update_model_name(event):\n",
    "    old_value = selector_paths.value\n",
    "    selector_paths.options = experiment_manager.get_model_wviz_paths(selector_model_name.value)\n",
    "    if old_value == selector_epoch.value:\n",
    "        update_epoch(event)\n",
    "selector_model_name.param.watch(update_model_name, 'value')\n",
    "\n",
    "def update_model_path(event):\n",
    "    old_value = selector_epoch.value\n",
    "    selector_epoch.options = experiment_manager.from_paths_to_epoch(selector_paths.value)\n",
    "    if old_value == selector_epoch.value:\n",
    "        update_epoch(event)\n",
    "selector_paths.param.watch(update_model_path, 'value')\n",
    "\n",
    "def update_epoch(event):\n",
    "    old_value = selector_split.value\n",
    "    selector_split.options = experiment_manager.from_epoch_to_split(selector_epoch.value)\n",
    "    if old_value == selector_split.value:\n",
    "        update_split(event)\n",
    "selector_epoch.param.watch(update_epoch, 'value')\n",
    "\n",
    "def update_split(event):\n",
    "    old_value = selector_file.value\n",
    "    selector_file.options = experiment_manager.from_split_to_file(selector_split.value)\n",
    "    if old_value == selector_file.value:\n",
    "        update_file(event)\n",
    "selector_split.param.watch(update_split, 'value')\n",
    "\n",
    "def update_file(event):\n",
    "    update_pointcloud(selector_file.value)\n",
    "selector_file.param.watch(update_file, 'value')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "dashboard"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
